{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "d92GsFoL0QPI"
   },
   "outputs": [],
   "source": [
    "import numpy as np # linear algebra\n",
    "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 124
    },
    "colab_type": "code",
    "id": "Hul0UX8W0fen",
    "outputId": "885c0627-b303-4962-db91-c21726bcb909"
   },
   "outputs": [],
   "source": [
    "# from google.colab import drive \n",
    "# drive.mount('/content/gdrive')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 226
    },
    "colab_type": "code",
    "id": "khAd4o0M0f7Q",
    "outputId": "3d77b2a8-bf9b-4770-daec-4c4bcbbd9ff5"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>duration</th>\n",
       "      <th>protocol_type</th>\n",
       "      <th>service</th>\n",
       "      <th>flag</th>\n",
       "      <th>src_bytes</th>\n",
       "      <th>dst_bytes</th>\n",
       "      <th>land</th>\n",
       "      <th>wrong_fragment</th>\n",
       "      <th>urgent</th>\n",
       "      <th>hot</th>\n",
       "      <th>...</th>\n",
       "      <th>dst_host_srv_count</th>\n",
       "      <th>dst_host_same_srv_rate</th>\n",
       "      <th>dst_host_diff_srv_rate</th>\n",
       "      <th>dst_host_same_src_port_rate</th>\n",
       "      <th>dst_host_srv_diff_host_rate</th>\n",
       "      <th>dst_host_serror_rate</th>\n",
       "      <th>dst_host_srv_serror_rate</th>\n",
       "      <th>dst_host_rerror_rate</th>\n",
       "      <th>dst_host_srv_rerror_rate</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>http</td>\n",
       "      <td>SF</td>\n",
       "      <td>181</td>\n",
       "      <td>5450</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>9</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.11</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>normal.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>http</td>\n",
       "      <td>SF</td>\n",
       "      <td>239</td>\n",
       "      <td>486</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>19</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.05</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>normal.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>http</td>\n",
       "      <td>SF</td>\n",
       "      <td>235</td>\n",
       "      <td>1337</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>29</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>normal.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>http</td>\n",
       "      <td>SF</td>\n",
       "      <td>219</td>\n",
       "      <td>1337</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>39</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>normal.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>http</td>\n",
       "      <td>SF</td>\n",
       "      <td>217</td>\n",
       "      <td>2032</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>49</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.02</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>normal.</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 42 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   duration protocol_type service flag  src_bytes  dst_bytes  land  \\\n",
       "0         0           tcp    http   SF        181       5450     0   \n",
       "1         0           tcp    http   SF        239        486     0   \n",
       "2         0           tcp    http   SF        235       1337     0   \n",
       "3         0           tcp    http   SF        219       1337     0   \n",
       "4         0           tcp    http   SF        217       2032     0   \n",
       "\n",
       "   wrong_fragment  urgent  hot  ...  dst_host_srv_count  \\\n",
       "0               0       0    0  ...                   9   \n",
       "1               0       0    0  ...                  19   \n",
       "2               0       0    0  ...                  29   \n",
       "3               0       0    0  ...                  39   \n",
       "4               0       0    0  ...                  49   \n",
       "\n",
       "   dst_host_same_srv_rate  dst_host_diff_srv_rate  \\\n",
       "0                     1.0                     0.0   \n",
       "1                     1.0                     0.0   \n",
       "2                     1.0                     0.0   \n",
       "3                     1.0                     0.0   \n",
       "4                     1.0                     0.0   \n",
       "\n",
       "   dst_host_same_src_port_rate  dst_host_srv_diff_host_rate  \\\n",
       "0                         0.11                          0.0   \n",
       "1                         0.05                          0.0   \n",
       "2                         0.03                          0.0   \n",
       "3                         0.03                          0.0   \n",
       "4                         0.02                          0.0   \n",
       "\n",
       "   dst_host_serror_rate  dst_host_srv_serror_rate  dst_host_rerror_rate  \\\n",
       "0                   0.0                       0.0                   0.0   \n",
       "1                   0.0                       0.0                   0.0   \n",
       "2                   0.0                       0.0                   0.0   \n",
       "3                   0.0                       0.0                   0.0   \n",
       "4                   0.0                       0.0                   0.0   \n",
       "\n",
       "   dst_host_srv_rerror_rate    label  \n",
       "0                       0.0  normal.  \n",
       "1                       0.0  normal.  \n",
       "2                       0.0  normal.  \n",
       "3                       0.0  normal.  \n",
       "4                       0.0  normal.  \n",
       "\n",
       "[5 rows x 42 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Column headings, in file order, as listed in kddcup.names; the last entry\n",
    "# (\"label\") names the class column.\n",
    "col_names = [\n",
    "    \"duration\", \"protocol_type\", \"service\", \"flag\", \"src_bytes\", \"dst_bytes\",\n",
    "    \"land\", \"wrong_fragment\", \"urgent\", \"hot\", \"num_failed_logins\", \"logged_in\",\n",
    "    \"num_compromised\", \"root_shell\", \"su_attempted\", \"num_root\",\n",
    "    \"num_file_creations\", \"num_shells\", \"num_access_files\", \"num_outbound_cmds\",\n",
    "    \"is_host_login\", \"is_guest_login\", \"count\", \"srv_count\", \"serror_rate\",\n",
    "    \"srv_serror_rate\", \"rerror_rate\", \"srv_rerror_rate\", \"same_srv_rate\",\n",
    "    \"diff_srv_rate\", \"srv_diff_host_rate\", \"dst_host_count\", \"dst_host_srv_count\",\n",
    "    \"dst_host_same_srv_rate\", \"dst_host_diff_srv_rate\",\n",
    "    \"dst_host_same_src_port_rate\", \"dst_host_srv_diff_host_rate\",\n",
    "    \"dst_host_serror_rate\", \"dst_host_srv_serror_rate\", \"dst_host_rerror_rate\",\n",
    "    \"dst_host_srv_rerror_rate\", \"label\",\n",
    "]\n",
    "\n",
    "# header=None: the first line of the file is read as data, so the headings\n",
    "# are supplied explicitly through `names`.\n",
    "df = pd.read_csv(\n",
    "    \"kddcup.data_10_percent_corrected\",\n",
    "    header=None,\n",
    "    names=col_names,\n",
    ")\n",
    "\n",
    "# Colab alternative (disabled): read the file from a mounted Google Drive.\n",
    "# df=pd.read_csv('/content/gdrive/My Drive/data/kddcup.data_10_percent_corrected')\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "HBeT3WVV0nww",
    "outputId": "d3b93147-b9a9-4e8d-8127-d436f87d3966"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(494021, 42)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: (number of rows, number of columns) of the loaded frame.\n",
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['normal.', 'buffer_overflow.', 'loadmodule.', 'perl.', 'neptune.',\n",
       "       'smurf.', 'guess_passwd.', 'pod.', 'teardrop.', 'portsweep.',\n",
       "       'ipsweep.', 'land.', 'ftp_write.', 'back.', 'imap.', 'satan.',\n",
       "       'phf.', 'nmap.', 'multihop.', 'warezmaster.', 'warezclient.',\n",
       "       'spy.', 'rootkit.'], dtype=object)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# List the distinct values found in the label column.\n",
    "df['label'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 121
    },
    "colab_type": "code",
    "id": "Wra_AdYd1K4o",
    "outputId": "c41f8c80-6fc1-42ef-b401-cbb594448912"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0    0\n",
      "1    0\n",
      "2    0\n",
      "3    0\n",
      "4    0\n",
      "Name: label, dtype: object\n"
     ]
    }
   ],
   "source": [
    "# Collapse the raw connection labels into five integer classes. Class 0 is\n",
    "# 'normal.'; classes 1-4 group the attack names exactly as listed below.\n",
    "# NOTE(review): the groups look like the usual KDD Cup 99 DoS / probe / R2L / U2R\n",
    "# families - confirm before relying on those names.\n",
    "label_groups = {\n",
    "    0: ['normal.'],\n",
    "    1: ['neptune.', 'back.', 'land.', 'pod.', 'smurf.', 'teardrop.', 'mailbomb.',\n",
    "        'apache2.', 'processtable.', 'udpstorm.', 'worm.'],\n",
    "    2: ['ipsweep.', 'nmap.', 'portsweep.', 'satan.', 'mscan.', 'saint.'],\n",
    "    3: ['ftp_write.', 'guess_passwd.', 'imap.', 'multihop.', 'phf.', 'spy.',\n",
    "        'warezclient.', 'warezmaster.', 'sendmail.', 'named.', 'snmpgetattack.',\n",
    "        'snmpguess.', 'xlock.', 'xsnoop.', 'httptunnel.'],\n",
    "    4: ['buffer_overflow.', 'loadmodule.', 'perl.', 'rootkit.', 'ps.', 'sqlattack.',\n",
    "        'xterm.'],\n",
    "}\n",
    "# Flatten {class: [names]} into the {name: class} lookup that replace() expects.\n",
    "label_to_class = {name: code for code, names in label_groups.items() for name in names}\n",
    "\n",
    "newlabeldf = df['label'].replace(label_to_class)\n",
    "\n",
    "# Fail fast: replace() silently leaves unknown labels untouched, which would only\n",
    "# surface later as an opaque error when the column is cast to int. Values that\n",
    "# are already 0-4 pass the check, so re-running this cell is harmless.\n",
    "unmapped = set(newlabeldf.unique()) - set(label_groups)\n",
    "if unmapped:\n",
    "    raise ValueError('Unmapped label values: {}'.format(sorted(map(str, unmapped))))\n",
    "\n",
    "df['label'] = newlabeldf\n",
    "#df['target'] = newlabeldf\n",
    "print(df['label'].head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Ediw-aU11RLX",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Print each column's dtype and non-null count.\n",
    "df.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "FqW5YY3C1YST"
   },
   "outputs": [],
   "source": [
    "\n",
    "from sklearn.preprocessing import LabelEncoder,OneHotEncoder\n",
    "\n",
    "# Integer-encode the three string-valued columns in place. Each column gets its\n",
    "# own fitted encoder, stored in `encoders`, so the exact string -> integer mapping\n",
    "# learned here can be re-applied to other data with encoders[col].transform(...).\n",
    "# Refitting one shared encoder would keep only the last column's mapping.\n",
    "categorical_cols = ['protocol_type', 'service', 'flag']\n",
    "encoders = {}\n",
    "for col in categorical_cols:\n",
    "    encoders[col] = LabelEncoder()\n",
    "    df[col] = encoders[col].fit_transform(df[col])\n",
    "\n",
    "# Backward compatibility: `le` was previously a single encoder that ended up\n",
    "# fitted on 'flag'; keep the name bound to that encoder.\n",
    "le = encoders['flag']\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "U5Q3f5St1dJy"
   },
   "outputs": [],
   "source": [
    "# Features are the first 41 columns; the target is the last column, cast to\n",
    "# integers.\n",
    "feature_slice = slice(None, 41)\n",
    "X = df.iloc[:, feature_slice]\n",
    "y = df.iloc[:, -1].astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 226
    },
    "colab_type": "code",
    "id": "H-2bq4BE1gWJ",
    "outputId": "9cb39821-3c60-40e0-a513-3b8de52df58a"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>duration</th>\n",
       "      <th>protocol_type</th>\n",
       "      <th>service</th>\n",
       "      <th>flag</th>\n",
       "      <th>src_bytes</th>\n",
       "      <th>dst_bytes</th>\n",
       "      <th>land</th>\n",
       "      <th>wrong_fragment</th>\n",
       "      <th>urgent</th>\n",
       "      <th>hot</th>\n",
       "      <th>...</th>\n",
       "      <th>dst_host_count</th>\n",
       "      <th>dst_host_srv_count</th>\n",
       "      <th>dst_host_same_srv_rate</th>\n",
       "      <th>dst_host_diff_srv_rate</th>\n",
       "      <th>dst_host_same_src_port_rate</th>\n",
       "      <th>dst_host_srv_diff_host_rate</th>\n",
       "      <th>dst_host_serror_rate</th>\n",
       "      <th>dst_host_srv_serror_rate</th>\n",
       "      <th>dst_host_rerror_rate</th>\n",
       "      <th>dst_host_srv_rerror_rate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>22</td>\n",
       "      <td>9</td>\n",
       "      <td>181</td>\n",
       "      <td>5450</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>9</td>\n",
       "      <td>9</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.11</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>22</td>\n",
       "      <td>9</td>\n",
       "      <td>239</td>\n",
       "      <td>486</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>19</td>\n",
       "      <td>19</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.05</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>22</td>\n",
       "      <td>9</td>\n",
       "      <td>235</td>\n",
       "      <td>1337</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>29</td>\n",
       "      <td>29</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>22</td>\n",
       "      <td>9</td>\n",
       "      <td>219</td>\n",
       "      <td>1337</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>39</td>\n",
       "      <td>39</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>22</td>\n",
       "      <td>9</td>\n",
       "      <td>217</td>\n",
       "      <td>2032</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>49</td>\n",
       "      <td>49</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.02</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 41 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   duration  protocol_type  service  flag  src_bytes  dst_bytes  land  \\\n",
       "0         0              1       22     9        181       5450     0   \n",
       "1         0              1       22     9        239        486     0   \n",
       "2         0              1       22     9        235       1337     0   \n",
       "3         0              1       22     9        219       1337     0   \n",
       "4         0              1       22     9        217       2032     0   \n",
       "\n",
       "   wrong_fragment  urgent  hot  ...  dst_host_count  dst_host_srv_count  \\\n",
       "0               0       0    0  ...               9                   9   \n",
       "1               0       0    0  ...              19                  19   \n",
       "2               0       0    0  ...              29                  29   \n",
       "3               0       0    0  ...              39                  39   \n",
       "4               0       0    0  ...              49                  49   \n",
       "\n",
       "   dst_host_same_srv_rate  dst_host_diff_srv_rate  \\\n",
       "0                     1.0                     0.0   \n",
       "1                     1.0                     0.0   \n",
       "2                     1.0                     0.0   \n",
       "3                     1.0                     0.0   \n",
       "4                     1.0                     0.0   \n",
       "\n",
       "   dst_host_same_src_port_rate  dst_host_srv_diff_host_rate  \\\n",
       "0                         0.11                          0.0   \n",
       "1                         0.05                          0.0   \n",
       "2                         0.03                          0.0   \n",
       "3                         0.03                          0.0   \n",
       "4                         0.02                          0.0   \n",
       "\n",
       "   dst_host_serror_rate  dst_host_srv_serror_rate  dst_host_rerror_rate  \\\n",
       "0                   0.0                       0.0                   0.0   \n",
       "1                   0.0                       0.0                   0.0   \n",
       "2                   0.0                       0.0                   0.0   \n",
       "3                   0.0                       0.0                   0.0   \n",
       "4                   0.0                       0.0                   0.0   \n",
       "\n",
       "   dst_host_srv_rerror_rate  \n",
       "0                       0.0  \n",
       "1                       0.0  \n",
       "2                       0.0  \n",
       "3                       0.0  \n",
       "4                       0.0  \n",
       "\n",
       "[5 rows x 41 columns]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first rows of the feature matrix.\n",
    "X.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "EnB5tkCC1h-y"
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Hold out 30% of the rows for evaluation; the fixed seed makes the split repeatable.\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, test_size=0.3, random_state=34\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "c6FtruIN1k9a"
   },
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import MinMaxScaler\n",
    "\n",
    "# Learn each feature's min/max from the training rows only, then rescale both\n",
    "# splits with those same statistics.\n",
    "scaler = MinMaxScaler(feature_range=(0, 1))\n",
    "scaler.fit(X_train)\n",
    "X_train_scaled, X_test_scaled = (scaler.transform(part) for part in (X_train, X_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 139
    },
    "colab_type": "code",
    "id": "nSwKjjSj1mpA",
    "outputId": "2f62a6d4-de3a-4b24-beae-3c400d169396"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RandomForestClassifier(bootstrap=True, class_weight='balanced',\n",
       "                       criterion='gini', max_depth=10, max_features='auto',\n",
       "                       max_leaf_nodes=None, min_impurity_decrease=0.0,\n",
       "                       min_impurity_split=None, min_samples_leaf=1,\n",
       "                       min_samples_split=2, min_weight_fraction_leaf=0.0,\n",
       "                       n_estimators=100, n_jobs=None, oob_score=False,\n",
       "                       random_state=0, verbose=0, warm_start=False)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "\n",
    "# Depth-limited forest of 100 trees with a fixed seed; class_weight='balanced'\n",
    "# weights classes inversely to their frequency in the training labels.\n",
    "forest_params = {\n",
    "    'n_estimators': 100,\n",
    "    'max_depth': 10,\n",
    "    'random_state': 0,\n",
    "    'class_weight': 'balanced',\n",
    "}\n",
    "model1 = RandomForestClassifier(**forest_params)\n",
    "# fit() returns the estimator itself, so the fitted model is echoed as the cell output.\n",
    "model1.fit(X_train_scaled, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "atDZW0ym1syj"
   },
   "outputs": [],
   "source": [
    "# Predicted class for every row of the scaled held-out split.\n",
    "predict = model1.predict(X_test_scaled)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "K8TdzvZ52aQy",
    "outputId": "3d778600-85a9-4663-ccbf-0109729bdd32"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9993927412335449"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean accuracy on the held-out split.\n",
    "model1.score(X_test_scaled,y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "OUVNaSp124d8",
    "outputId": "53a8bf20-7543-461f-c2c5-b9bd02b5dcc0"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.999635642281689"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean accuracy on the training split, for comparison with the held-out score.\n",
    "model1.score(X_train_scaled,y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 104
    },
    "colab_type": "code",
    "id": "n4BtTqXu28nC",
    "outputId": "25306594-09cb-4154-f9fd-bbe9a4ed9d6b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 29073      1      0     54      2]\n",
      " [     9 117470      0      0      1]\n",
      " [    13      0   1254      0      0]\n",
      " [     7      1      0    305      0]\n",
      " [     2      0      0      0     15]]\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import confusion_matrix,recall_score,precision_score\n",
    "\n",
    "# Rows are the true classes, columns the predicted classes.\n",
    "conf_mat = confusion_matrix(y_test, predict)\n",
    "print(conf_mat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 225
    },
    "colab_type": "code",
    "id": "7WoU9Ty73AMP",
    "outputId": "94302332-bace-46ed-9294-44ea8cf92d55"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       1.00      1.00      1.00     29130\n",
      "           1       1.00      1.00      1.00    117480\n",
      "           2       1.00      0.99      0.99      1267\n",
      "           3       0.85      0.97      0.91       313\n",
      "           4       0.83      0.88      0.86        17\n",
      "\n",
      "    accuracy                           1.00    148207\n",
      "   macro avg       0.94      0.97      0.95    148207\n",
      "weighted avg       1.00      1.00      1.00    148207\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "\n",
    "# Per-class precision, recall and F1 on the held-out split.\n",
    "report_text = classification_report(y_test, predict)\n",
    "print(report_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jJUis4ku3CaW"
   },
   "source": [
    "## Predicting on the actual 10% test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Load the unlabeled test file, naming its columns with every heading except the\n",
    "# trailing \"label\".\n",
    "df_test = pd.read_csv(\"kddcup.newtestdata_10_percent_unlabeled\", header=None, names = col_names[:-1])\n",
    "# NOTE(review): the disabled lines below would re-fit the encoder on the test file,\n",
    "# which can assign different integer codes than the ones used for training -\n",
    "# verify the encoding strategy before enabling them.\n",
    "# df_test['protocol_type'] = le.fit_transform(df_test['protocol_type'])\n",
    "# df_test['service']= le.fit_transform(df_test['service'])\n",
    "# df_test['flag'] = le.fit_transform(df_test['flag'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>duration</th>\n",
       "      <th>protocol_type</th>\n",
       "      <th>service</th>\n",
       "      <th>flag</th>\n",
       "      <th>src_bytes</th>\n",
       "      <th>dst_bytes</th>\n",
       "      <th>land</th>\n",
       "      <th>wrong_fragment</th>\n",
       "      <th>urgent</th>\n",
       "      <th>hot</th>\n",
       "      <th>...</th>\n",
       "      <th>dst_host_count</th>\n",
       "      <th>dst_host_srv_count</th>\n",
       "      <th>dst_host_same_srv_rate</th>\n",
       "      <th>dst_host_diff_srv_rate</th>\n",
       "      <th>dst_host_same_src_port_rate</th>\n",
       "      <th>dst_host_srv_diff_host_rate</th>\n",
       "      <th>dst_host_serror_rate</th>\n",
       "      <th>dst_host_srv_serror_rate</th>\n",
       "      <th>dst_host_rerror_rate</th>\n",
       "      <th>dst_host_srv_rerror_rate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>smtp</td>\n",
       "      <td>SF</td>\n",
       "      <td>829</td>\n",
       "      <td>327</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>8</td>\n",
       "      <td>113</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.02</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>udp</td>\n",
       "      <td>private</td>\n",
       "      <td>SF</td>\n",
       "      <td>105</td>\n",
       "      <td>146</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>255</td>\n",
       "      <td>253</td>\n",
       "      <td>0.99</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>udp</td>\n",
       "      <td>private</td>\n",
       "      <td>SF</td>\n",
       "      <td>105</td>\n",
       "      <td>146</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>255</td>\n",
       "      <td>254</td>\n",
       "      <td>1.00</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>udp</td>\n",
       "      <td>private</td>\n",
       "      <td>SF</td>\n",
       "      <td>105</td>\n",
       "      <td>146</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>255</td>\n",
       "      <td>253</td>\n",
       "      <td>0.99</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>ftp_data</td>\n",
       "      <td>SF</td>\n",
       "      <td>19</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>226</td>\n",
       "      <td>46</td>\n",
       "      <td>0.19</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.19</td>\n",
       "      <td>0.04</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 41 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   duration protocol_type   service flag  src_bytes  dst_bytes  land  \\\n",
       "0         0           tcp      smtp   SF        829        327     0   \n",
       "1         0           udp   private   SF        105        146     0   \n",
       "2         0           udp   private   SF        105        146     0   \n",
       "3         0           udp   private   SF        105        146     0   \n",
       "4         0           tcp  ftp_data   SF         19          0     0   \n",
       "\n",
       "   wrong_fragment  urgent  hot  ...  dst_host_count  dst_host_srv_count  \\\n",
       "0               0       0    0  ...               8                 113   \n",
       "1               0       0    0  ...             255                 253   \n",
       "2               0       0    0  ...             255                 254   \n",
       "3               0       0    0  ...             255                 253   \n",
       "4               0       0    0  ...             226                  46   \n",
       "\n",
       "   dst_host_same_srv_rate  dst_host_diff_srv_rate  \\\n",
       "0                    0.88                    0.25   \n",
       "1                    0.99                    0.01   \n",
       "2                    1.00                    0.01   \n",
       "3                    0.99                    0.01   \n",
       "4                    0.19                    0.03   \n",
       "\n",
       "   dst_host_same_src_port_rate  dst_host_srv_diff_host_rate  \\\n",
       "0                         0.12                         0.02   \n",
       "1                         0.00                         0.00   \n",
       "2                         0.01                         0.00   \n",
       "3                         0.00                         0.00   \n",
       "4                         0.19                         0.04   \n",
       "\n",
       "   dst_host_serror_rate  dst_host_srv_serror_rate  dst_host_rerror_rate  \\\n",
       "0                   0.0                       0.0                   0.0   \n",
       "1                   0.0                       0.0                   0.0   \n",
       "2                   0.0                       0.0                   0.0   \n",
       "3                   0.0                       0.0                   0.0   \n",
       "4                   0.0                       0.0                   0.0   \n",
       "\n",
       "   dst_host_srv_rerror_rate  \n",
       "0                       0.0  \n",
       "1                       0.0  \n",
       "2                       0.0  \n",
       "3                       0.0  \n",
       "4                       0.0  \n",
       "\n",
       "[5 rows x 41 columns]"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_test['protocol_type'] = le.fit_transform(df_test['protocol_type'])\n",
    "df_test['service']= le.fit_transform(df_test['service'])\n",
    "df_test['flag'] = le.fit_transform(df_test['flag'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>duration</th>\n",
       "      <th>protocol_type</th>\n",
       "      <th>service</th>\n",
       "      <th>flag</th>\n",
       "      <th>src_bytes</th>\n",
       "      <th>dst_bytes</th>\n",
       "      <th>land</th>\n",
       "      <th>wrong_fragment</th>\n",
       "      <th>urgent</th>\n",
       "      <th>hot</th>\n",
       "      <th>...</th>\n",
       "      <th>dst_host_count</th>\n",
       "      <th>dst_host_srv_count</th>\n",
       "      <th>dst_host_same_srv_rate</th>\n",
       "      <th>dst_host_diff_srv_rate</th>\n",
       "      <th>dst_host_same_src_port_rate</th>\n",
       "      <th>dst_host_srv_diff_host_rate</th>\n",
       "      <th>dst_host_serror_rate</th>\n",
       "      <th>dst_host_srv_serror_rate</th>\n",
       "      <th>dst_host_rerror_rate</th>\n",
       "      <th>dst_host_srv_rerror_rate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>50</td>\n",
       "      <td>8</td>\n",
       "      <td>829</td>\n",
       "      <td>327</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>8</td>\n",
       "      <td>113</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.02</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>46</td>\n",
       "      <td>8</td>\n",
       "      <td>105</td>\n",
       "      <td>146</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>255</td>\n",
       "      <td>253</td>\n",
       "      <td>0.99</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>46</td>\n",
       "      <td>8</td>\n",
       "      <td>105</td>\n",
       "      <td>146</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>255</td>\n",
       "      <td>254</td>\n",
       "      <td>1.00</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>46</td>\n",
       "      <td>8</td>\n",
       "      <td>105</td>\n",
       "      <td>146</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>255</td>\n",
       "      <td>253</td>\n",
       "      <td>0.99</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>19</td>\n",
       "      <td>8</td>\n",
       "      <td>19</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>226</td>\n",
       "      <td>46</td>\n",
       "      <td>0.19</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.19</td>\n",
       "      <td>0.04</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 41 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   duration  protocol_type  service  flag  src_bytes  dst_bytes  land  \\\n",
       "0         0              1       50     8        829        327     0   \n",
       "1         0              2       46     8        105        146     0   \n",
       "2         0              2       46     8        105        146     0   \n",
       "3         0              2       46     8        105        146     0   \n",
       "4         0              1       19     8         19          0     0   \n",
       "\n",
       "   wrong_fragment  urgent  hot  ...  dst_host_count  dst_host_srv_count  \\\n",
       "0               0       0    0  ...               8                 113   \n",
       "1               0       0    0  ...             255                 253   \n",
       "2               0       0    0  ...             255                 254   \n",
       "3               0       0    0  ...             255                 253   \n",
       "4               0       0    0  ...             226                  46   \n",
       "\n",
       "   dst_host_same_srv_rate  dst_host_diff_srv_rate  \\\n",
       "0                    0.88                    0.25   \n",
       "1                    0.99                    0.01   \n",
       "2                    1.00                    0.01   \n",
       "3                    0.99                    0.01   \n",
       "4                    0.19                    0.03   \n",
       "\n",
       "   dst_host_same_src_port_rate  dst_host_srv_diff_host_rate  \\\n",
       "0                         0.12                         0.02   \n",
       "1                         0.00                         0.00   \n",
       "2                         0.01                         0.00   \n",
       "3                         0.00                         0.00   \n",
       "4                         0.19                         0.04   \n",
       "\n",
       "   dst_host_serror_rate  dst_host_srv_serror_rate  dst_host_rerror_rate  \\\n",
       "0                   0.0                       0.0                   0.0   \n",
       "1                   0.0                       0.0                   0.0   \n",
       "2                   0.0                       0.0                   0.0   \n",
       "3                   0.0                       0.0                   0.0   \n",
       "4                   0.0                       0.0                   0.0   \n",
       "\n",
       "   dst_host_srv_rerror_rate  \n",
       "0                       0.0  \n",
       "1                       0.0  \n",
       "2                       0.0  \n",
       "3                       0.0  \n",
       "4                       0.0  \n",
       "\n",
       "[5 rows x 41 columns]"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "predicted = model1.predict(df_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Found input variables with inconsistent numbers of samples: [148207, 311079]",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-29-fbca30c0c465>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclassification_report\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpredicted\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/home/rahulr/snap/jupyter/common/lib/python3.7/site-packages/sklearn/metrics/classification.py\u001b[0m in \u001b[0;36mclassification_report\u001b[0;34m(y_true, y_pred, labels, target_names, sample_weight, digits, output_dict)\u001b[0m\n\u001b[1;32m   1850\u001b[0m     \"\"\"\n\u001b[1;32m   1851\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1852\u001b[0;31m     \u001b[0my_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_check_targets\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1853\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1854\u001b[0m     \u001b[0mlabels_given\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/home/rahulr/snap/jupyter/common/lib/python3.7/site-packages/sklearn/metrics/classification.py\u001b[0m in \u001b[0;36m_check_targets\u001b[0;34m(y_true, y_pred)\u001b[0m\n\u001b[1;32m     69\u001b[0m     \u001b[0my_pred\u001b[0m \u001b[0;34m:\u001b[0m \u001b[0marray\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mindicator\u001b[0m \u001b[0mmatrix\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     70\u001b[0m     \"\"\"\n\u001b[0;32m---> 71\u001b[0;31m     \u001b[0mcheck_consistent_length\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     72\u001b[0m     \u001b[0mtype_true\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtype_of_target\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     73\u001b[0m     \u001b[0mtype_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtype_of_target\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/home/rahulr/snap/jupyter/common/lib/python3.7/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36mcheck_consistent_length\u001b[0;34m(*arrays)\u001b[0m\n\u001b[1;32m    203\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0muniques\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    204\u001b[0m         raise ValueError(\"Found input variables with inconsistent numbers of\"\n\u001b[0;32m--> 205\u001b[0;31m                          \" samples: %r\" % [int(l) for l in lengths])\n\u001b[0m\u001b[1;32m    206\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    207\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Found input variables with inconsistent numbers of samples: [148207, 311079]"
     ]
    }
   ],
   "source": [
    "print(classification_report(y_test, predicted))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "name": "kaggle_code.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
